from langchain.agents import AgentType, initialize_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit

from src.ai.langchain.common.mysql_conn import get_mysql_db
from src.ai.langchain.init_llm import get_llm

# Demo script: answer a natural-language question about a MySQL database
# using a LangChain zero-shot ReAct agent armed with SQLDatabaseToolkit tools.

# Project helpers: get_llm() supplies the language model; get_mysql_db()
# supplies the database handle (presumably a LangChain SQLDatabase, since
# SQLDatabaseToolkit expects one — confirm in src.ai.langchain.common.mysql_conn).
llm = get_llm()

db = get_mysql_db()

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent_executor = initialize_agent(
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    tools=toolkit.get_tools(),
    verbose=True,
    # ReAct-formatted LLM output is not always well-formed; feed parse
    # errors back to the model instead of aborting the whole run.
    handle_parsing_errors=True,
)

# "How many rows are in the users table?"
question = "用户表中一共有多少数据？"
# Chain.run() is deprecated; invoke() takes a dict keyed by the executor's
# input key ("input") and returns a dict whose "output" entry is the final
# answer string that run() used to return.
response = agent_executor.invoke({"input": question})["output"]
# Prints the "Answer:" label followed by the agent's final answer.
print("回答结果：", response)